package sqltemplate

import (
	"fmt"
	"reflect"
	"strings"
	"text/template"
)

const (
	funcNamePrepare = "prepare"
	funcNameParam   = "param"
)

type SqlTemplate struct {
	queryArgs []interface{}      //sql占位符参数列表
	tpl       *template.Template //模板实例
}

func New(sql string) *SqlTemplate {
	result := new(SqlTemplate)
	tpl := template.New("default")
	//注册自定义函数
	tpl.Funcs(result.GetFuncMap())
	//解析sql模板
	tpl, err := tpl.Parse(sql)
	if err != nil {
		panic(err)
	}
	result.tpl = tpl
	return result
}

//返回自定义函数
func (s *SqlTemplate) GetFuncMap() template.FuncMap {
	return template.FuncMap{
		funcNamePrepare: s.prepare,
		funcNameParam:   s.param,
	}
}

//返回sql占位符参数列表
func (s SqlTemplate) QueryArgs() []interface{} {
	return s.queryArgs
}

//解析sql
//返回预编译的sql模板与占位符参数列表
func (s *SqlTemplate) Parse(params interface{}) (string, []interface{}) {
	var buf strings.Builder
	s.tpl.Execute(&buf, params)
	return buf.String(), s.queryArgs
}

//将入参替换为预编译占位符，slice类型自动展开
//eg.{{prepare .id}} ==> ?
//eg.{{prepare .idList}} ==> ?,?,?
func (s *SqlTemplate) prepare(value interface{}) string {
	refV := reflect.ValueOf(value)
	var b strings.Builder
	switch refV.Kind() {
	case reflect.Slice:
		//展开slice类型
		for i := 0; i < refV.Len(); i++ {
			if i == 0 {
				b.WriteString("?")
			} else {
				b.WriteString(",?")
			}
			//记录占位符参数
			s.queryArgs = append(s.queryArgs, refV.Index(i).Interface())
		}
		return b.String()
	default:
		//记录占位符参数
		s.queryArgs = append(s.queryArgs, value)
		return "?"
	}
	return ""
}

//将入参替换为字符串
//eg.’{{param .const}}‘ ==> 'const's value'
func (s *SqlTemplate) param(value interface{}) string {

	return fmt.Sprintf("%v", value)
}
